{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "clutter_maskrcnn_train.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.4"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DfPPQ6ztJhv4"
      },
      "source": [
        "# Mask R-CNN for Bin Picking\n",
        "\n",
        "This notebook is adapted from the [TorchVision 0.3 Object Detection fine-tuning tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html). We will fine-tune a pre-trained [Mask R-CNN](https://arxiv.org/abs/1703.06870) model on a dataset generated by our \"clutter generator\" script.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DBIoe_tHTQgV"
      },
      "source": [
        "!pip install cython\n",
        "# Install pycocotools from source; the version installed by default in Colab\n",
        "# has a bug that was fixed in https://github.com/cocodataset/cocoapi/pull/354\n",
        "!pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'\n",
        "\n",
        "# Download TorchVision repo to use some files from\n",
        "# references/detection\n",
        "!git clone https://github.com/pytorch/vision.git\n",
        "!cd vision && git checkout v0.3.0\n",
        "!cp vision/references/detection/utils.py ./\n",
        "!cp vision/references/detection/transforms.py ./\n",
        "!cp vision/references/detection/coco_eval.py ./\n",
        "!cp vision/references/detection/engine.py ./\n",
        "!cp vision/references/detection/coco_utils.py ./\n",
        "\n",
        "# Determine if this notebook is currently running as a notebook or a unit test.\n",
        "from IPython import get_ipython\n",
        "running_as_notebook = get_ipython() and hasattr(get_ipython(), 'kernel')\n",
        "\n",
        "# Imports\n",
        "import fnmatch\n",
        "import json\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "import multiprocessing\n",
        "import numpy as np\n",
        "import os\n",
        "from PIL import Image\n",
        "from IPython.display import display\n",
        "\n",
        "import torch\n",
        "import torch.utils.data\n",
        "\n",
        "ycb = [\n",
        "    \"003_cracker_box.sdf\", \"004_sugar_box.sdf\", \"005_tomato_soup_can.sdf\",\n",
        "    \"006_mustard_bottle.sdf\", \"009_gelatin_box.sdf\", \"010_potted_meat_can.sdf\"\n",
        "]\n",
        "\n",
        "# Drake's reserved render label values; colorize_labels paints them as background.\n",
        "reserved_labels = [32765, 32764, 32766, 32767]\n",
        "\n",
        "def colorize_labels(image):\n",
        "    \"\"\"Colorizes labels.\"\"\"\n",
        "    cc = mpl.colors.ColorConverter()\n",
        "    color_cycle = plt.rcParams[\"axes.prop_cycle\"]\n",
        "    colors = np.array([cc.to_rgb(c[\"color\"]) for c in color_cycle])\n",
        "    bg_color = [0, 0, 0]\n",
        "    image = np.squeeze(image)\n",
        "    background = np.zeros(image.shape[:2], dtype=bool)\n",
        "    for label in reserved_labels:\n",
        "        background |= image == int(label)\n",
        "    color_image = colors[image % len(colors)]\n",
        "    color_image[background] = bg_color\n",
        "    return color_image\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XwyE5A8DGtct"
      },
      "source": [
        "# Download our bin-picking dataset\n",
        "\n",
        "It's possible to create this dataset on Colab: I've written a version of the \"clutter_gen\" method from the last chapter that writes the images (and label images) to disk, along with some annotations. But generating 10,000 images takes a non-trivial amount of time, so here we download a pre-generated copy instead.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_DgAgqauIET9"
      },
      "source": [
        "dataset_path = 'clutter_maskrcnn_data'\n",
        "if not os.path.exists(dataset_path):\n",
        "    !wget https://groups.csail.mit.edu/locomotion/clutter_maskrcnn_data.zip\n",
        "    !unzip -q clutter_maskrcnn_data.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xA8sBvuHNNH1"
      },
      "source": [
        "If you are on Colab, use the file browser on the left (the drive-like icon under the table of contents panel) to click through the .png and .json files and make sure you understand the dataset you've just downloaded! If you're on a local machine, just browse to the `clutter_maskrcnn_data` folder."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C9Ee5NV54Dmj"
      },
      "source": [
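        "# Teach PyTorch how to load the dataset\n",
        "\n",
        "Each example must be loaded into the [format expected by Mask R-CNN](https://pytorch.org/docs/stable/torchvision/models.html#torchvision.models.detection.maskrcnn_resnet50_fpn): an RGB image plus a `target` dictionary containing the instances' `boxes`, `labels`, and `masks`."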
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mTgWtixZTs3X"
      },
      "source": [
        "\n",
        "class BinPickingDataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, root, transforms=None):\n",
        "        self.root = root\n",
        "        self.num_images = len(fnmatch.filter(os.listdir(root),'*.png'))\n",
        "        self.transforms = transforms\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        filename_base = os.path.join(self.root, f\"{idx:05d}\")\n",
        "\n",
        "        img = Image.open(filename_base + \".png\").convert(\"RGB\")\n",
        "        mask = np.squeeze(np.load(filename_base + \"_mask.npy\"))\n",
        "\n",
        "        with open(filename_base + \".json\", \"r\") as f:\n",
        "            instance_id_to_class_name = json.load(f)\n",
        "\n",
        "        # each instance is encoded as a distinct integer id in the label image\n",
        "        obj_ids = np.asarray(list(instance_id_to_class_name.keys()))\n",
        "        count = (mask == np.int16(obj_ids)[:, None, None]).sum(axis=2).sum(axis=1)\n",
        "\n",
        "        # discard object instances with fewer than 10 pixels\n",
        "        obj_ids = obj_ids[count >= 10]\n",
        "\n",
        "        # label 0 is reserved for the background class, so YCB classes start at 1\n",
        "        labels = [ycb.index(instance_id_to_class_name[obj_id] + \".sdf\") + 1\n",
        "                  for obj_id in obj_ids]\n",
        "        obj_ids = np.int16(np.asarray(obj_ids))\n",
        "\n",
        "        # split the color-encoded mask into a set of binary masks\n",
        "        masks = mask == obj_ids[:, None, None]\n",
        "\n",
        "        # get bounding box coordinates for each mask\n",
        "        num_objs = len(obj_ids)\n",
        "        boxes = []\n",
        "        for i in range(num_objs):\n",
        "            pos = np.where(masks[i])\n",
        "            xmin = np.min(pos[1])\n",
        "            xmax = np.max(pos[1])\n",
        "            ymin = np.min(pos[0])\n",
        "            ymax = np.max(pos[0])\n",
        "            boxes.append([xmin, ymin, xmax, ymax])\n",
        "\n",
        "        # reshape keeps boxes at shape [0, 4] even if every instance was filtered out\n",
        "        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)\n",
        "        labels = torch.as_tensor(labels, dtype=torch.int64)\n",
        "        masks = torch.as_tensor(masks, dtype=torch.uint8)\n",
        "\n",
        "        image_id = torch.tensor([idx])\n",
        "        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])\n",
        "        # suppose all instances are not crowd\n",
        "        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)\n",
        "\n",
        "        target = {}\n",
        "        target[\"boxes\"] = boxes\n",
        "        target[\"labels\"] = labels\n",
        "        target[\"masks\"] = masks\n",
        "        target[\"image_id\"] = image_id\n",
        "        target[\"area\"] = area\n",
        "        target[\"iscrowd\"] = iscrowd\n",
        "\n",
        "        if self.transforms is not None:\n",
        "            img, target = self.transforms(img, target)\n",
        "\n",
        "        return img, target\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.num_images\n"
      ],
      "execution_count": null,
      "outputs": []
    },
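    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the mask-splitting step in `__getitem__` concrete, here is a small, self-contained sketch on a synthetic label image. The 6x8 array, the instance ids 3 and 7, and the `min_pixels` threshold are made up for illustration. Comparing the label image against a column of instance ids yields one binary mask per instance, and each box is read off the mask's nonzero rows and columns in `[xmin, ymin, xmax, ymax]` order, just as in the dataset class."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical 6x8 label image: 0 is background, 3 and 7 are instance ids.\n",
        "label_image = np.zeros((6, 8), dtype=np.int16)\n",
        "label_image[1:3, 2:5] = 3  # a 2x3 block, i.e. 6 pixels\n",
        "label_image[4, 6] = 7      # a single pixel\n",
        "\n",
        "obj_ids = np.array([3, 7], dtype=np.int16)\n",
        "min_pixels = 2  # the dataset class uses 10; smaller here to fit the toy image\n",
        "\n",
        "# Comparing [H, W] against [N, 1, 1] broadcasts to N binary masks of shape [H, W].\n",
        "masks = label_image == obj_ids[:, None, None]\n",
        "counts = masks.sum(axis=(1, 2))\n",
        "\n",
        "# Drop instances that cover too few pixels.\n",
        "keep = counts >= min_pixels\n",
        "obj_ids, masks = obj_ids[keep], masks[keep]\n",
        "\n",
        "# Each box spans the extreme nonzero columns (x) and rows (y) of its mask.\n",
        "boxes = []\n",
        "for m in masks:\n",
        "    rows, cols = np.where(m)\n",
        "    boxes.append([int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())])\n",
        "\n",
        "print(obj_ids.tolist(), boxes)"
      ],
      "execution_count": null,
      "outputs": []
    },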
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J6f3ZOTJ4Km9"
      },
      "source": [
        "Let's check the output of our dataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZEARO4B_ye0s"
      },
      "source": [
        "dataset = BinPickingDataset(dataset_path)\n",
        "dataset[0][0]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xWA2NXwVhV_C"
      },
      "source": [
        "# Define the network\n",
        "\n",
        "This cell is where the magic begins to happen. We load a network that is pre-trained on the COCO dataset, then replace its box-classification and mask-prediction heads with new (untrained) ones that have the right number of outputs for our YCB recognition/segmentation task: one per YCB object class, plus one for the background."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YjNHjVMOyYlH"
      },
      "source": [
        "import torchvision\n",
        "from torchvision.models.detection.faster_rcnn import FastRCNNPredictor\n",
        "from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor\n",
        "\n",
        "      \n",
        "def get_instance_segmentation_model(num_classes):\n",
        "    # load an instance segmentation model pre-trained on COCO\n",
        "    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)\n",
        "\n",
        "    # get the number of input features for the classifier\n",
        "    in_features = model.roi_heads.box_predictor.cls_score.in_features\n",
        "    # replace the pre-trained head with a new one\n",
        "    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)\n",
        "\n",
        "    # now get the number of input features for the mask classifier\n",
        "    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels\n",
        "    hidden_layer = 256\n",
        "    # and replace the mask predictor with a new one\n",
        "    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,\n",
        "                                                       hidden_layer,\n",
        "                                                       num_classes)\n",
        "\n",
        "    return model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-WXLwePV5ieP"
      },
      "source": [
        "That's it: the model is now ready to be trained and evaluated on our custom dataset.\n",
        "\n",
        "# Transforms\n",
        "\n",
        "Let's write some helper functions for data augmentation / transformation, which leverage the functions in torchvision's `references/detection`.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l79ivkwKy357"
      },
      "source": [
        "from engine import train_one_epoch, evaluate\n",
        "import utils\n",
        "import transforms as T\n",
        "\n",
        "def get_transform(train):\n",
        "    transforms = []\n",
        "    # converts the image, a PIL image, into a PyTorch Tensor\n",
        "    transforms.append(T.ToTensor())\n",
        "    if train:\n",
        "        # during training, randomly flip the training images\n",
        "        # and ground-truth for data augmentation\n",
        "        transforms.append(T.RandomHorizontalFlip(0.5))\n",
        "    return T.Compose(transforms)"
      ],
      "execution_count": null,
      "outputs": []
    },
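    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Flipping only works as augmentation if the targets are flipped along with the pixels. The NumPy sketch below is a simplified, hypothetical stand-in (`hflip_sample` is not part of torchvision) for what `T.RandomHorizontalFlip` needs to do when it fires: mirror the `[C, H, W]` image and the `[N, H, W]` masks, and map each box's x-coordinates through `x' = W - x`, treating coordinates as continuous positions in `[0, W]`. Check the copied `transforms.py` for the exact behavior of the reference version."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def hflip_sample(image, masks, boxes):\n",
        "    # Hypothetical helper: mirror one training sample left-to-right.\n",
        "    # image: [C, H, W], masks: [N, H, W], boxes: [N, 4] as xmin, ymin, xmax, ymax.\n",
        "    width = image.shape[-1]\n",
        "    image = image[..., ::-1]\n",
        "    masks = masks[..., ::-1]\n",
        "    boxes = boxes.copy()\n",
        "    # After mirroring, the old xmax becomes the new xmin and vice versa.\n",
        "    boxes[:, [0, 2]] = width - boxes[:, [2, 0]]\n",
        "    return image, masks, boxes\n",
        "\n",
        "# Toy sample: a 4-pixel-wide image with one object along its left edge.\n",
        "image = np.arange(2 * 3 * 4).reshape(2, 3, 4)\n",
        "masks = np.zeros((1, 3, 4), dtype=bool)\n",
        "masks[0, :, 0] = True\n",
        "boxes = np.array([[0.0, 0.0, 1.0, 3.0]])\n",
        "\n",
        "f_image, f_masks, f_boxes = hflip_sample(image, masks, boxes)\n",
        "print(f_boxes)  # the object now sits along the right edge"
      ],
      "execution_count": null,
      "outputs": []
    },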
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FzCLqiZk-sjf"
      },
      "source": [
        "Note that we do not need to add a mean/std normalization nor image rescaling in the data transforms, as those are handled internally by the Mask R-CNN model."
      ]
    },
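    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the Mask R-CNN model wraps its inputs in torchvision's `GeneralizedRCNNTransform`. As far as we know, with the default arguments it standardizes each channel with the ImageNet mean and standard deviation, and rescales the image so its shorter side is 800 pixels unless that would push the longer side past 1333. The NumPy sketch below reproduces only that arithmetic; the constants are assumptions about torchvision's defaults, so check them against the version you have installed."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "# Assumed defaults of torchvision's maskrcnn_resnet50_fpn (check your installed version).\n",
        "IMAGE_MEAN = np.array([0.485, 0.456, 0.406])\n",
        "IMAGE_STD = np.array([0.229, 0.224, 0.225])\n",
        "MIN_SIZE, MAX_SIZE = 800, 1333\n",
        "\n",
        "def normalize(image_chw):\n",
        "    # Per-channel standardization of a [3, H, W] image with values in [0, 1].\n",
        "    return (image_chw - IMAGE_MEAN[:, None, None]) / IMAGE_STD[:, None, None]\n",
        "\n",
        "def resize_scale(height, width):\n",
        "    # Make the short side MIN_SIZE, unless that pushes the long side past MAX_SIZE.\n",
        "    short_side, long_side = min(height, width), max(height, width)\n",
        "    return min(MIN_SIZE / short_side, MAX_SIZE / long_side)\n",
        "\n",
        "print(resize_scale(480, 640))   # limited by the short side: 800 / 480\n",
        "print(resize_scale(400, 1000))  # limited by the long side: 1333 / 1000"
      ],
      "execution_count": null,
      "outputs": []
    },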
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3YFJGJxk6XEs"
      },
      "source": [
        "# Putting everything together\n",
        "\n",
        "We now have the dataset class, the model, and the data transforms. Let's instantiate them."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a5dGaIezze3y"
      },
      "source": [
        "# use our dataset and defined transformations\n",
        "dataset = BinPickingDataset(dataset_path, get_transform(train=True))\n",
        "dataset_test = BinPickingDataset(dataset_path, get_transform(train=False))\n",
        "\n",
        "# split the dataset in train and test set\n",
        "torch.manual_seed(1)\n",
        "indices = torch.randperm(len(dataset)).tolist()\n",
        "dataset = torch.utils.data.Subset(dataset, indices[:-50])\n",
        "dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])\n",
        "\n",
        "# define training and validation data loaders\n",
        "data_loader = torch.utils.data.DataLoader(\n",
        "    dataset, batch_size=2, shuffle=True, num_workers=4,\n",
        "    collate_fn=utils.collate_fn)\n",
        "\n",
        "data_loader_test = torch.utils.data.DataLoader(\n",
        "    dataset_test, batch_size=1, shuffle=False, num_workers=4,\n",
        "    collate_fn=utils.collate_fn)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L5yvZUprj4ZN"
      },
      "source": [
        "Now let's instantiate the model and the optimizer."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zoenkCj18C4h"
      },
      "source": [
        "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
        "\n",
        "num_classes = len(ycb)+1\n",
        "\n",
        "# get the model using our helper function\n",
        "model = get_instance_segmentation_model(num_classes)\n",
        "# move model to the right device\n",
        "model.to(device)\n",
        "\n",
        "# construct an optimizer\n",
        "params = [p for p in model.parameters() if p.requires_grad]\n",
        "optimizer = torch.optim.SGD(params, lr=0.005,\n",
        "                            momentum=0.9, weight_decay=0.0005)\n",
        "\n",
        "# and a learning rate scheduler which decreases the learning rate by\n",
        "# 10x every 3 epochs\n",
        "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,\n",
        "                                               step_size=3,\n",
        "                                               gamma=0.1)"
      ],
      "execution_count": null,
      "outputs": []
    },
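    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With `step_size=3` and `gamma=0.1`, and `lr_scheduler.step()` called once per epoch, the learning rate used in epoch `e` (counting from 0) works out to `0.005 * 0.1 ** (e // 3)`. The plain-Python sketch below simply evaluates that formula for the 10 epochs we train; it does not touch the real optimizer."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "base_lr, gamma, step_size, num_epochs = 0.005, 0.1, 3, 10\n",
        "\n",
        "# StepLR multiplies the learning rate by gamma once every step_size epochs,\n",
        "# so epoch e (counting from 0) uses base_lr * gamma ** (e // step_size).\n",
        "lrs = [base_lr * gamma ** (epoch // step_size) for epoch in range(num_epochs)]\n",
        "for epoch, lr in enumerate(lrs):\n",
        "    print(f'epoch {epoch}: lr = {lr:.0e}')"
      ],
      "execution_count": null,
      "outputs": []
    },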
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XAd56lt4kDxc"
      },
      "source": [
        "And now let's train the model for 10 epochs, evaluating at the end of every epoch."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "at-h4OWK0aoc"
      },
      "source": [
        "# let's train it for 10 epochs\n",
        "num_epochs = 10\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "    # train for one epoch, printing every 10 iterations\n",
        "    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)\n",
        "    # update the learning rate\n",
        "    lr_scheduler.step()\n",
        "    # evaluate on the test dataset\n",
        "    evaluate(model, data_loader_test, device=device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XXTyZhCScUTI"
      },
      "source": [
        "If you're going to leave this running for a while, I recommend queuing the following cell to run as soon as training finishes, so that you don't lose your work if the runtime disconnects."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vUJXn15pGzRj"
      },
      "source": [
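        "torch.save(model.state_dict(), 'clutter_maskrcnn_model.pt')\n",
        "\n",
        "# google.colab only exists on Colab; elsewhere the file stays in the working directory\n",
        "try:\n",
        "    from google.colab import files\n",
        "    files.download('clutter_maskrcnn_model.pt')\n",
        "except ImportError:\n",
        "    print('Not running on Colab; model saved to clutter_maskrcnn_model.pt')"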
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z6mYGFLxkO8F"
      },
      "source": [
        "Now that training has finished, let's look at what the model actually predicts on a test image."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YHwIdxH76uPj"
      },
      "source": [
        "# pick one image from the test set\n",
        "img, _ = dataset_test[0]\n",
        "# put the model in evaluation mode\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    prediction = model([img.to(device)])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DmN602iKsuey"
      },
      "source": [
        "Printing the prediction shows that we have a list of dictionaries. Each element of the list corresponds to a different image. As we have a single image, there is a single dictionary in the list.\n",
        "The dictionary contains the predictions for the image we passed. In this case, we can see that it contains `boxes`, `labels`, `masks` and `scores` as fields."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lkmb3qUu6zw3"
      },
      "source": [
        "prediction"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RwT21rzotFbH"
      },
      "source": [
        "Let's inspect the image and the predicted segmentation masks.\n",
        "\n",
        "For that, we need to convert the image back: `ToTensor` rescaled its values to 0-1 and permuted it to `[C, H, W]` format, so we scale it back to 0-255 and permute it to `[H, W, C]`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bpqN9t1u7B2J"
      },
      "source": [
        "Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M58J3O9OtT1G"
      },
      "source": [
        "And let's now visualize the top predicted segmentation mask. The masks are predicted as `[N, 1, H, W]`, where `N` is the number of predictions, and are probability maps between 0-1."
      ]
    },
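    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before using these masks downstream, it is common to keep only confident detections and binarize each probability map. The sketch below does that on NumPy arrays shaped like the model output; the helper `binarize_predictions` and its thresholds (0.7 on the score, 0.5 on the mask probability) are illustrative assumptions, not values tuned for this dataset. On the real output, you would first convert the tensors, e.g. with `prediction[0]['scores'].cpu().numpy()` and `prediction[0]['masks'].cpu().numpy()`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import numpy as np\n",
        "\n",
        "def binarize_predictions(scores, masks, score_threshold=0.7, mask_threshold=0.5):\n",
        "    # Hypothetical helper. scores: [N]; masks: [N, 1, H, W] probabilities in [0, 1].\n",
        "    keep = scores >= score_threshold\n",
        "    # Select confident detections, drop the singleton channel, then threshold.\n",
        "    binary_masks = masks[keep, 0] >= mask_threshold\n",
        "    return np.flatnonzero(keep), binary_masks\n",
        "\n",
        "# Fake model output: two detections on a 2x3 image.\n",
        "scores = np.array([0.95, 0.30])\n",
        "masks = np.array([\n",
        "    [[[0.9, 0.6, 0.1], [0.2, 0.8, 0.4]]],\n",
        "    [[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]],\n",
        "])\n",
        "kept, binary_masks = binarize_predictions(scores, masks)\n",
        "print(kept, binary_masks.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },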
    {
      "cell_type": "code",
      "metadata": {
        "id": "5v5S3bm07SO1"
      },
      "source": [
        "Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}